// Copyright 2022 The Google Research Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "test/common/AcceleratorHarness.h"

#include <mc_connections.h>
#include <systemc.h>

#include "src/AccelTypes.h"

// Testbench that wraps the accelerator with a clock, a reset generator, a
// memory model backed by a caller-supplied array, and parameter/done drivers.
//
// Clock arguments follow sc_clock(name, period, unit, duty, start, unit,
// posedge_first): 1 ns period, 50% duty cycle, starting at 0 ns with a rising
// edge first.
//
// NOTE(review): `memory` is kept as a raw, non-owning pointer and is indexed
// without bounds checks by the memory-access and store processes; the caller
// must keep it valid (and large enough for every address the accelerator
// issues) for the whole simulation.
AcceleratorHarness::AcceleratorHarness(sc_module_name name, Params params,
                                       VectorParams vectorParams,
                                       INPUT_DATATYPE *memory)
    : sc_module(name),
      clk("clk", 1, SC_NS, 0.5, 0, SC_NS, true),
      params(params),
      vectorParams(vectorParams),
      mainMemory(memory) {
  // Shared clock, active-low reset, and the two parameter channels.
  accelerator.clk(clk);
  accelerator.rstn(rstn);
  accelerator.paramsIn(paramsIn);
  accelerator.vectorParamsIn(vectorParamsIn);

  // Per-core channels: each accelerator port is bound to the harness channel
  // of the same name.
  for (int i = 0; i < NUM_CORES; i++) {
    accelerator.inputAddressRequest[i](inputAddressRequest[i]);
    accelerator.inputDataResponse[i](inputDataResponse[i]);
    accelerator.weightAddressRequest[i](weightAddressRequest[i]);
    accelerator.weightDataResponse[i](weightDataResponse[i]);
    accelerator.vectorFetchAddressRequest[i](vectorFetchAddressRequest[i]);
    accelerator.vectorFetchDataResponse[i](vectorFetchDataResponse[i]);
    accelerator.vectorUnitOutput[i](vectorUnitOutput[i]);
    accelerator.outputAddress[i](outputAddress[i]);
    accelerator.done[i](done[i]);
  }

  // Reset generator, clocked on `clk`.
  SC_CTHREAD(reset, clk);

  // Each `sensitive` / `async_reset_signal_is` pair below applies to the
  // process registered immediately before it, so every group must stay
  // together and in this order. All threads wake on the rising clock edge and
  // are held in reset while rstn is low (active-low reset).
  SC_THREAD(memoryAccessInputs);
  sensitive << clk.posedge_event();
  async_reset_signal_is(rstn, false);

  SC_THREAD(memoryAccessWeights);
  sensitive << clk.posedge_event();
  async_reset_signal_is(rstn, false);

  SC_THREAD(memoryAccessVector);
  sensitive << clk.posedge_event();
  async_reset_signal_is(rstn, false);

  SC_THREAD(storeOutputs);
  sensitive << clk.posedge_event();
  async_reset_signal_is(rstn, false);

  SC_THREAD(sendParams);
  sensitive << clk.posedge_event();
  async_reset_signal_is(rstn, false);

  SC_THREAD(sendVectorParams);
  sensitive << clk.posedge_event();
  async_reset_signal_is(rstn, false);

  SC_THREAD(waitForDone);
  sensitive << clk.posedge_event();
  async_reset_signal_is(rstn, false);
}

// Reset generator (clocked thread). Drives rstn low on its first activation,
// keeps it low for five further clock cycles, then releases it (drives it
// high) and returns after one more clock edge.
void AcceleratorHarness::reset() {
  rstn.write(0);
  // Hold reset asserted for five cycles.
  wait(5);
  rstn.write(1);
  wait();
}

// Memory model for burst-read ports.
//
// Once per clock cycle, every core's request channel is polled without
// blocking. An accepted request reads `burstSize` consecutive words starting
// at `address` from mainMemory and appends them to that core's staging vector.
// Each time the staging vector holds DIMENSION words, it is pushed (blocking)
// on the core's response channel and refilling starts from lane 0. A partially
// filled vector carries over into the next request.
void AcceleratorHarness::memAccessBurst(
    Connections::Combinational<MemoryRequest> addressRequest[NUM_CORES],
    Connections::Combinational<Pack1D<INPUT_DATATYPE, DIMENSION> >
        dataResponse[NUM_CORES]) {
  // Per-core staging buffer and the number of lanes already written into it.
  Pack1D<INPUT_DATATYPE, DIMENSION> staged[NUM_CORES];
  int fill[NUM_CORES];

  for (int core = 0; core < NUM_CORES; core++) {
    addressRequest[core].ResetRead();
    dataResponse[core].ResetWrite();
    fill[core] = 0;
  }

  wait();

  while (true) {
    for (int core = 0; core < NUM_CORES; core++) {
      MemoryRequest request;
      if (!addressRequest[core].PopNB(request)) {
        continue;  // Nothing requested by this core this cycle.
      }

      for (int offset = 0; offset < request.burstSize; offset++) {
        staged[core].value[fill[core]++] =
            mainMemory[request.address + offset];

        if (fill[core] == DIMENSION) {
          fill[core] = 0;
          dataResponse[core].Push(staged[core]);
        }
      }
    }
    wait();
  }
}

// Memory model for single-word read ports that return packed vectors.
//
// Once per clock cycle, every core's address channel is polled without
// blocking. Each accepted address reads one word from mainMemory into the next
// free lane of that core's staging vector. When all DIMENSION lanes are
// filled, the vector is pushed (blocking) on the core's response channel and
// refilling restarts at lane 0.
void AcceleratorHarness::memAccessPack(
    Connections::Combinational<int> addressRequest[NUM_CORES],
    Connections::Combinational<Pack1D<INPUT_DATATYPE, DIMENSION> >
        dataResponse[NUM_CORES]) {
  // Per-core staging buffer and the number of lanes already written into it.
  Pack1D<INPUT_DATATYPE, DIMENSION> staged[NUM_CORES];
  int fill[NUM_CORES];

  for (int core = 0; core < NUM_CORES; core++) {
    addressRequest[core].ResetRead();
    dataResponse[core].ResetWrite();
    fill[core] = 0;
  }

  wait();

  while (true) {
    for (int core = 0; core < NUM_CORES; core++) {
      int addr;
      if (!addressRequest[core].PopNB(addr)) {
        continue;  // No address from this core this cycle.
      }

      staged[core].value[fill[core]++] = mainMemory[addr];

      if (fill[core] == DIMENSION) {
        fill[core] = 0;
        dataResponse[core].Push(staged[core]);
      }
    }
    wait();
  }
}

// Scalar memory model: once per clock cycle, every core's address channel is
// polled without blocking. Each accepted address reads one word from
// mainMemory, converts it to OUTPUT_DATATYPE, and pushes it (blocking) on that
// core's response channel.
void AcceleratorHarness::memAccess(
    Connections::Combinational<int> addressRequest[NUM_CORES],
    Connections::Combinational<OUTPUT_DATATYPE> dataResponse[NUM_CORES]) {
  for (int core = 0; core < NUM_CORES; core++) {
    addressRequest[core].ResetRead();
    dataResponse[core].ResetWrite();
  }
  wait();

  while (true) {
    for (int core = 0; core < NUM_CORES; core++) {
      int addr;
      if (!addressRequest[core].PopNB(addr)) {
        continue;  // No request from this core this cycle.
      }

      OUTPUT_DATATYPE word = mainMemory[addr];
      dataResponse[core].Push(word);
    }
    wait();
  }
}

void AcceleratorHarness::memoryAccessInputs() {
  memAccessBurst(inputAddressRequest, inputDataResponse);
}

void AcceleratorHarness::memoryAccessWeights() {
  memAccessBurst(weightAddressRequest, weightDataResponse);
}

void AcceleratorHarness::memoryAccessVector() {
  memAccessPack(vectorFetchAddressRequest, vectorFetchDataResponse);
}

// Thread body: sends the stored `params` to the accelerator once.
// Resets the output side of paramsIn, waits one clock edge, and performs a
// (blocking) Push. The process then waits one more edge and returns, so no
// further parameters are sent.
void AcceleratorHarness::sendParams() {
  paramsIn.ResetWrite();

  wait();

  paramsIn.Push(params);
  wait();
}

// Thread body: sends the stored `vectorParams` to the accelerator once.
// Resets the output side of vectorParamsIn, waits one clock edge, and performs
// a (blocking) Push. The process then waits one more edge and returns, so no
// further vector parameters are sent.
void AcceleratorHarness::sendVectorParams() {
  vectorParamsIn.ResetWrite();

  wait();

  vectorParamsIn.Push(vectorParams);
  wait();
}

void AcceleratorHarness::storeOutputs() {
  for (int i = 0; i < NUM_CORES; i++) {
    vectorUnitOutput[i].ResetRead();
    outputAddress[i].ResetRead();
  }

  wait();

  while (true) {
    for (int i = 0; i < NUM_CORES; i++) {
      int address;
      if (outputAddress[i].PopNB(address)) {
        Pack1D<OUTPUT_DATATYPE, DIMENSION> data = vectorUnitOutput[i].Pop();
        // std::cout << "output @ " << address << " : " << std::endl;
        for (int j = 0; j < DIMENSION; j++) {
          mainMemory[address + j] = data[j];
          // std::cout << data[j] << " ";
        }
        // std::cout << std::endl;
      }
    }
    wait();
  }
}

// Thread body: ends the simulation once every core has reported completion.
// Resets the read side of each core's `done` channel, waits one clock edge,
// then calls SyncPop on each core's `done` channel in index order. After all
// cores have been consumed, it logs a message and calls sc_stop().
void AcceleratorHarness::waitForDone() {
  for (int i = 0; i < NUM_CORES; i++) {
    done[i].ResetRead();
  }

  wait();

  // Cores are awaited sequentially, so completion order does not matter:
  // the loop only finishes once every core has signalled.
  for (int i = 0; i < NUM_CORES; i++) {
    done[i].SyncPop();
  }

  CCS_LOG("Accelerator Finished.");

  sc_stop();
}

// Builds an AcceleratorHarness over the caller-owned `mainMemory` array and
// runs the SystemC simulation. sc_start() with no arguments returns once a
// process calls sc_stop(), or when no events remain. Results are left in
// `mainMemory`.
//
// NOTE(review): SystemC does not allow elaborating new modules after
// simulation has started. Calling this more than once per program is
// presumably unsupported — confirm with callers.
void run_op(const Params params, const VectorParams vectorParams,
            INPUT_DATATYPE *mainMemory) {
  AcceleratorHarness testbench("harness", params, vectorParams, mainMemory);
  sc_start();
}
